import numpy as np


def im2col(img, kernel_size):
    """Unfold a multi-channel image into a matrix of sliding-window patches.

    Windows are "valid" (stride 1, no padding), so each output spatial
    dimension is ``input_dim - kernel_dim + 1``.

    Args:
        img: array-like of shape ``(channels, height, width)``.
        kernel_size: ``(kh, kw)`` window height and width.

    Returns:
        A float64 array of shape ``(out_h * out_w, channels * kh * kw)``.
        Row ``i * out_w + j`` holds the window whose top-left corner is
        ``(i, j)``; column ``ch * kh * kw + m * kw + n`` holds
        ``img[ch, i + m, j + n]``. If the kernel is larger than the image in
        either dimension, the result has zero rows.
    """
    img = np.asarray(img)
    c, height, width = img.shape[0], img.shape[1], img.shape[2]
    kh, kw = kernel_size
    # Output size follows from the kernel size (not an assumed 3x3 kernel);
    # clamp at 0 so an oversized kernel yields an empty result instead of a
    # negative dimension.
    out_h = max(height - kh + 1, 0)
    out_w = max(width - kw + 1, 0)
    cols = np.zeros((out_h * out_w, c * kh * kw))
    for ch in range(c):
        for m in range(kh):
            for n in range(kw):
                # Element (m, n) of every window in channel ch is simply the
                # image shifted by (m, n); flattening it row-major lines up
                # with the window ordering of the output rows.
                cols[:, ch * kh * kw + m * kw + n] = img[
                    ch, m:m + out_h, n:n + out_w
                ].ravel()
    return cols
